"""
Generate publication-quality figures for the followup paper.

Figures:
1. fig1_age_coefficients.png - Implied age-group coefficients from cubic polynomial
2. fig3_rolling_coefficients.png - Rolling-window Z_1, Z_3 coefficients and R-squared
3. fig6_projections.png - Projected demographic pressure 1980-2060
4. fig7_model_comparison.png - Model comparison R-squared across 4 specifications
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path

# Machine-specific absolute paths (the /mnt/c prefix suggests a Windows drive
# mounted under WSL -- NOTE(review): confirm before running elsewhere).
# Input CSV tables are read from TAB_DIR; every figure is written to FIG_DIR.
TAB_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup/output/tables")
FIG_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup/paper/figures")
# Side effect at import time: ensure the output directory exists before any savefig.
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Five-year age bands 0-4 ... 75-79, followed by the open-ended 80+ band.
AGE_LABELS = [f'{lo}-{lo + 4}' for lo in range(0, 80, 5)] + ['80+']
# Number of age groups (17); used as the length of every age-coefficient vector.
G = len(AGE_LABELS)


def recover_age_coefficients(gamma_hat, P=3):
    """
    Map polynomial parameters back to one coefficient per age group.

    The implied coefficients are alpha_g = gamma_0 + sum_{p=1..P} gamma_p * g**p
    for g = 1..G. The intercept gamma_0 is chosen so that the G coefficients
    sum to zero.

    Args:
        gamma_hat: indexable whose first P entries are gamma_1 .. gamma_P.
        P: polynomial degree (3 = cubic).

    Returns:
        numpy array of length G with the implied age-group coefficients.
    """
    groups = np.arange(1, G + 1)
    # Pair each polynomial weight with the matching power of the group index.
    weighted_powers = [(gamma_hat[k - 1], groups ** k) for k in range(1, P + 1)]

    # Intercept that forces the coefficients to sum to zero.
    intercept = 0
    for weight, powered in weighted_powers:
        intercept -= weight * np.sum(powered)
    intercept /= G

    coefs = np.full(G, intercept)
    for weight, powered in weighted_powers:
        coefs += weight * powered
    return coefs


def recover_age_se(gamma_hat, gamma_se, P=3, n_sim=10000, cov=None):
    """
    Standard errors of the implied age-group coefficients.

    recover_age_coefficients is linear in the polynomial parameters:
    alpha = A @ gamma[:P], where column p of A is the alpha vector produced by
    the p-th unit vector. For a linear map the delta method is exact,
    Var(alpha) = A @ Sigma @ A.T, so the SEs are computed in closed form
    rather than by simulation (a Monte Carlo estimate would only converge to
    this same value, with sampling noise).

    Args:
        gamma_hat: point estimates of gamma_1 .. gamma_P. The SEs of a linear
            map do not depend on them; the argument is kept for compatibility.
        gamma_se: standard errors of gamma_1 .. gamma_P. Used, treating the
            estimates as independent, when ``cov`` is None.
        P: polynomial degree.
        n_sim: accepted for backward compatibility; no longer used.
        cov: optional covariance matrix of gamma (its leading P x P block is
            used). Overrides ``gamma_se``. Pass it when available: the Z_p
            regressors are powers of the same age index, so their estimates
            are likely correlated and the independence assumption may misstate
            the SEs.

    Returns:
        numpy array of length G with the standard error of each age coefficient.
    """
    # Column p of the design matrix = implied alphas for the unit vector e_p.
    design = np.column_stack(
        [recover_age_coefficients(unit, P) for unit in np.eye(P)]
    )
    if cov is None:
        se = np.asarray(gamma_se, dtype=float)[:P]
        sigma = np.diag(se ** 2)
    else:
        sigma = np.asarray(cov, dtype=float)[:P, :P]
    # diag(A @ Sigma @ A.T) without forming the full G x G matrix.
    variances = np.einsum('gp,pq,gq->g', design, sigma, design)
    # Clip tiny negative values (floating-point round-off) before the sqrt.
    return np.sqrt(np.clip(variances, 0.0, None))


# ============================================================
# Figure 1: Age-Group Coefficients
# ============================================================
def make_fig1():
    """
    Plot the implied age-group coefficients with 95% confidence intervals.

    Reads the Z_1..Z_3 polynomial coefficients and standard errors of the
    baseline model (Model 2: Demo + EBA) from TAB_DIR and converts them into
    one coefficient per age group. Saves fig1_age_coefficients.png in FIG_DIR.
    Error bars are clipped to +/-1.5x the largest |coefficient|. A note in the
    lower-left corner reports how many intervals were clipped.
    """
    print("Generating fig1_age_coefficients.png ...")
    # Use baseline model (Model 2: Demo + EBA) coefficients
    reg = pd.read_csv(TAB_DIR / "regression_baseline_demo_plus_eba_140.csv")
    # One row per regressor; keep the first if a variable name repeats.
    reg = reg.drop_duplicates('variable').set_index('variable')
    poly_terms = ['Z_1', 'Z_2', 'Z_3']
    gamma_hat = reg.loc[poly_terms, 'coefficient'].to_numpy(dtype=float)
    gamma_se = reg.loc[poly_terms, 'std_error'].to_numpy(dtype=float)

    alpha = recover_age_coefficients(gamma_hat)
    # NOTE(review): only standard errors are passed, so covariances between
    # the Z_p estimates are ignored and the CIs below are approximate.
    alpha_se = recover_age_se(gamma_hat, gamma_se)

    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(G)

    # Color bars: green for positive (surplus pressure), red for negative (deficit pressure)
    colors = ['#2ca02c' if a > 0 else '#d62728' for a in alpha]
    ax.bar(x, alpha, color=colors, alpha=0.8, edgecolor='#333333', linewidth=0.5)

    # 95% CI error bars -- lighter style, clipped to the visible range
    ci_upper = alpha + 1.96 * alpha_se
    ci_lower = alpha - 1.96 * alpha_se
    # Visible range: enough to show the bars plus a modest CI zone. Fall back
    # to a unit range if every coefficient is zero (a zero-height axis is degenerate).
    coef_max = np.max(np.abs(alpha))
    y_lim = coef_max * 1.5 if coef_max > 0 else 1.0
    ci_upper_c = np.clip(ci_upper, -y_lim, y_lim)
    ci_lower_c = np.clip(ci_lower, -y_lim, y_lim)
    ax.errorbar(x, alpha, yerr=[alpha - ci_lower_c, ci_upper_c - alpha], fmt='none',
                color='#555555', capsize=3, linewidth=0.8, capthick=0.8, alpha=0.6)

    ax.axhline(y=0, color='black', linewidth=0.8, linestyle='-')
    ax.set_xticks(x)
    ax.set_xticklabels(AGE_LABELS, rotation=45, ha='right', fontsize=9)
    ax.set_xlabel('Age Group', fontsize=12)
    ax.set_ylabel('Coefficient (effect on CA/GDP, pp)', fontsize=12)
    ax.set_title('Implied Age-Group Coefficients from Cubic Polynomial\n(Baseline Model, 140 Countries)',
                 fontsize=13, fontweight='bold')
    ax.set_ylim(-y_lim, y_lim)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Annotate the peak among age groups 0-4 through 60-64 (inclusive), i.e.
    # excluding the 65+ groups. Derive the cut-off from the labels so it
    # cannot drift out of sync with AGE_LABELS.
    n_working = AGE_LABELS.index('60-64') + 1
    peak_idx = int(np.argmax(alpha[:n_working]))
    peak_color = colors[peak_idx]  # match the bar (red if the peak is negative)
    ax.annotate(f'Peak: {AGE_LABELS[peak_idx]}',
                xy=(peak_idx, alpha[peak_idx]),
                xytext=(peak_idx + 2, alpha[peak_idx] + y_lim * 0.15),
                fontsize=8, color=peak_color,
                arrowprops=dict(arrowstyle='->', color=peak_color, lw=1))

    # Note about CI clipping
    n_clipped = int(np.sum((ci_upper > y_lim) | (ci_lower < -y_lim)))
    if n_clipped > 0:
        ax.text(0.02, 0.02,
                f'Note: CIs for {n_clipped} age groups extend beyond axis limits',
                transform=ax.transAxes, fontsize=7, color='gray', style='italic')

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig1_age_coefficients.png", dpi=150, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.close()
    print("  Saved fig1_age_coefficients.png")


# ============================================================
# Figure 3: Rolling-Window Coefficients
# ============================================================
def _pvalue_implied_halfwidth(coef, pval, level=0.95):
    """
    Half-width of the confidence interval implied by a coefficient and its p-value.

    Uses the normal approximation for a two-sided test. The implied standard
    error is |coef| / z_obs with z_obs = Phi^{-1}(1 - p/2), and the half-width
    is z_crit * SE with z_crit = Phi^{-1}(0.5 + level/2). At level=0.95 the
    interval touches zero exactly when p = 0.05, so the band is consistent
    with a p < 0.05 significance cut-off.

    Returns:
        float array of half-widths: NaN where the interval is undefined
        (p >= 1, or non-finite coef/p), 0 where p <= 0.
    """
    from statistics import NormalDist  # stdlib; avoids a scipy dependency

    std_normal = NormalDist()
    z_crit = std_normal.inv_cdf(0.5 + level / 2)
    abs_coef = np.abs(np.asarray(coef, dtype=float))
    p = np.asarray(pval, dtype=float)

    halfwidth = np.full(abs_coef.shape, np.nan)
    finite = np.isfinite(abs_coef) & np.isfinite(p)
    halfwidth[finite & (p <= 0)] = 0.0
    usable = finite & (p > 0) & (p < 1)
    # inv_cdf(p / 2) stays accurate for tiny p, where 1 - p/2 would round to 1.
    z_obs = np.array([-std_normal.inv_cdf(pi / 2) for pi in p[usable]], dtype=float)
    halfwidth[usable] = z_crit * abs_coef[usable] / z_obs
    return halfwidth


def _plot_coef_panel(ax, rw, coef_col, pval_col, fmt, band_color, point_color, name):
    """Draw one rolling-coefficient panel: estimate line, implied 95% CI band, and points coloured by p < 0.05."""
    x = rw['mid_year']
    coef = rw[coef_col]
    halfwidth = _pvalue_implied_halfwidth(coef, rw[pval_col])

    ax.plot(x, coef, fmt, linewidth=2, label=f'{name} coefficient')
    ax.fill_between(x, coef - halfwidth, coef + halfwidth, where=np.isfinite(halfwidth),
                    alpha=0.15, color=band_color, label='95% CI (implied by p-value)')
    sig_mask = rw[pval_col] < 0.05
    ax.scatter(rw.loc[sig_mask, 'mid_year'], rw.loc[sig_mask, coef_col],
               color=point_color, s=30, zorder=5, label='p < 0.05')
    ax.scatter(rw.loc[~sig_mask, 'mid_year'], rw.loc[~sig_mask, coef_col],
               color='gray', s=20, zorder=4, alpha=0.5, label='p >= 0.05')
    ax.axhline(y=0, color='black', linewidth=0.5, linestyle='-')
    ax.set_ylabel(f'{name} Coefficient', fontsize=11)
    ax.legend(fontsize=8, loc='upper left')


def make_fig3():
    """
    Plot rolling-window estimates of the Z_1 and Z_3 coefficients and R^2.

    Reads rolling_window_demo_signal_expanded_108.csv from TAB_DIR. It uses the
    columns start_year, end_year, Z1_coef, Z1_pval, Z3_coef, Z3_pval, r2_full
    and r2_controls. Windows are placed at their midpoint year. The
    coefficient panels shade the 95% CI implied by each p-value (normal
    approximation). Saves fig3_rolling_coefficients.png in FIG_DIR.
    """
    print("Generating fig3_rolling_coefficients.png ...")
    rw = pd.read_csv(TAB_DIR / "rolling_window_demo_signal_expanded_108.csv")

    # Use midpoint of window as x-axis
    rw['mid_year'] = (rw['start_year'] + rw['end_year']) / 2

    fig, axes = plt.subplots(3, 1, figsize=(11, 10), sharex=True)

    # NOTE(review): event years are drawn on a window-midpoint axis, so each
    # line marks the window centred on the event, not the first window that
    # contains it -- confirm this is the intended reading.
    event_lines = [
        (2001, 'WTO accession', '#1f77b4'),
        (2008, 'GFC', '#d62728'),
        (2018, 'Tariff onset', '#ff7f0e'),
    ]

    # Panels 1-2: Z_1 and Z_3 coefficients
    _plot_coef_panel(axes[0], rw, 'Z1_coef', 'Z1_pval', 'b-', 'steelblue', 'blue', '$Z_1$')
    _plot_coef_panel(axes[1], rw, 'Z3_coef', 'Z3_pval', 'g-', 'green', 'green', '$Z_3$')

    # Panel 3: R-squared (full model vs controls only)
    ax = axes[2]
    ax.plot(rw['mid_year'], rw['r2_full'], 'k-', linewidth=2, label='Full model $R^2$')
    ax.plot(rw['mid_year'], rw['r2_controls'], 'k--', linewidth=1.5, alpha=0.6,
            label='Controls only $R^2$')
    ax.fill_between(rw['mid_year'], rw['r2_controls'], rw['r2_full'],
                    alpha=0.2, color='purple', label='$\\Delta R^2$ (demographics)')
    ax.set_ylabel('$R^2$', fontsize=11)
    ax.set_xlabel('Window Midpoint Year', fontsize=12)
    ax.legend(fontsize=8, loc='upper left')

    # Shared styling and event markers for all three panels
    for panel in axes:
        panel.grid(alpha=0.3, linestyle='--')
        panel.spines['top'].set_visible(False)
        panel.spines['right'].set_visible(False)
        for yr, _, clr in event_lines:
            panel.axvline(x=yr, color=clr, linewidth=1, linestyle='--', alpha=0.7)

    # Event labels on the bottom panel: x in data units, y as an axes fraction,
    # so they sit at the top regardless of the R^2 range.
    for yr, lbl, clr in event_lines:
        axes[2].text(yr, 0.98, lbl, transform=axes[2].get_xaxis_transform(),
                     rotation=90, va='top', ha='right', fontsize=7, color=clr)

    fig.suptitle('Rolling 15-Year Window Estimates (Expanded 108-Country Sample)',
                 fontsize=13, fontweight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(FIG_DIR / "fig3_rolling_coefficients.png", dpi=150, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.close()
    print("  Saved fig3_rolling_coefficients.png")


# ============================================================
# Figure 6: Projections
# ============================================================
def make_fig6():
    """
    Plot historical and projected demographic contributions to CA/GDP.

    Historical values (solid lines, 1980 through 2024) come from
    demographic_contributions_140.csv. Projected values (dashed lines, years
    after 2024) come from the wide projection_table_140.csv, whose first column
    identifies the country and whose remaining columns are years. When both
    exist, the dashed segment starts at the last historical point so the two
    pieces connect. Countries with only projections are still drawn, and
    countries with no data are reported and skipped.
    Saves fig6_projections.png in FIG_DIR.
    """
    print("Generating fig6_projections.png ...")
    first_year, last_hist_year = 1980, 2024

    proj = pd.read_csv(TAB_DIR / "projection_table_140.csv")

    # Reshape from wide (one column per year) to long
    id_col = proj.columns[0]  # 'Country'
    # NOTE(review): rows are matched against the ISO3 codes in `countries`
    # below; if this column holds country names, no projection will be drawn.
    year_cols = [c for c in proj.columns if c != id_col]
    proj_long = proj.melt(id_vars=id_col, value_vars=year_cols,
                          var_name='year', value_name='demo_contribution')
    proj_long['year'] = proj_long['year'].astype(int)
    proj_long = proj_long.rename(columns={id_col: 'country'})

    # Full demographic contributions, used for the historical segment
    contrib = pd.read_csv(TAB_DIR / "demographic_contributions_140.csv")

    # Selected countries matching paper
    countries = ['CHN', 'IND', 'JPN', 'USA', 'DEU', 'BRA', 'NGA', 'KOR']

    # Fixed per-country colours (matplotlib tab10 hues).
    # NOTE(review): CHN (red) and USA (green) are hard to tell apart for
    # red-green colour-blind readers; consider a colour-blind-safe palette.
    colors = {
        'CHN': '#d62728',   # red
        'IND': '#ff7f0e',   # orange
        'JPN': '#1f77b4',   # blue
        'USA': '#2ca02c',   # green
        'DEU': '#9467bd',   # purple
        'BRA': '#8c564b',   # brown
        'NGA': '#e377c2',   # pink
        'KOR': '#17becf',   # cyan
    }

    fig, ax = plt.subplots(figsize=(12, 7))

    for country in countries:
        color = colors[country]
        hist = contrib.loc[contrib['iso3'] == country, ['year', 'demo_contribution']]
        hist = hist[hist['year'].between(first_year, last_hist_year)].sort_values('year')
        future = proj_long[(proj_long['country'] == country)
                           & (proj_long['year'] > last_hist_year)].sort_values('year')

        if hist.empty and future.empty:
            print(f"  Warning: no historical or projected data for {country}; skipped")
            continue

        if not hist.empty:
            ax.plot(hist['year'], hist['demo_contribution'],
                    color=color, linewidth=2, label=country)

        if not future.empty:
            years = future['year'].tolist()
            values = future['demo_contribution'].tolist()
            dashed_kw = {}
            if hist.empty:
                # Projection-only country: the dashed line carries the legend entry.
                dashed_kw['label'] = country
            else:
                # Bridge from the last historical point so the segments connect.
                years.insert(0, hist['year'].iloc[-1])
                values.insert(0, hist['demo_contribution'].iloc[-1])
            ax.plot(years, values, color=color, linewidth=2, linestyle='--', **dashed_kw)

    ax.axhline(y=0, color='black', linewidth=0.8)
    ax.axvline(x=last_hist_year, color='gray', linewidth=1.2, linestyle='--', alpha=0.7)
    # x in data units, y as an axes fraction: keeps the label near the top even
    # when the upper y-limit is small or negative.
    ax.text(last_hist_year + 0.5, 0.95, str(last_hist_year),
            transform=ax.get_xaxis_transform(), fontsize=9, color='gray', va='top')

    ax.set_xlabel('Year', fontsize=12)
    ax.set_ylabel('Demographic Contribution to CA/GDP (pp)', fontsize=12)
    ax.set_title('Projected Demographic Pressure on Current Accounts\nSelected Economies, 1980-2060',
                 fontsize=13, fontweight='bold')
    ax.legend(loc='upper left', fontsize=10, ncol=2, framealpha=0.9)
    ax.grid(alpha=0.3, linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(first_year, 2062)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig6_projections.png", dpi=150, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.close()
    print("  Saved fig6_projections.png")


# ============================================================
# Figure 7: Model Comparison
# ============================================================
def make_fig7():
    """
    Draw a bar chart comparing R^2 across the four model specifications.

    The numbers are hard-coded from the paper's comparison table (Model 4 is
    the pension-interaction specification); nothing is read from TAB_DIR.
    Each bar shows its R^2 above it and its sample size (observations and
    countries) inside it. Saves fig7_model_comparison.png in FIG_DIR.
    """
    print("Generating fig7_model_comparison.png ...")

    # (axis label, R^2, observations, countries, bar colour) -- from paper table
    table = [
        ('Model 1:\nDemographics Only', 0.056, 5323, 141, '#1f77b4'),
        ('Model 2:\nBaseline (Demo+EBA)', 0.273, 2730, 137, '#2ca02c'),
        ('Model 3:\nExtended (+Rates)', 0.290, 1626, 90, '#ff7f0e'),
        ('Model 4:\nPension Interactions', 0.137, 750, 40, '#9467bd'),
    ]
    models, r2_vals, n_obs, n_countries, bar_colors = (list(col) for col in zip(*table))

    fig, ax = plt.subplots(figsize=(9, 5.5))

    positions = np.arange(len(models))
    bars = ax.bar(positions, r2_vals, color=bar_colors, alpha=0.85, edgecolor='#333333',
                  linewidth=0.8, width=0.6)

    # R^2 label above each bar; sample size centred inside it
    for bar, r2, n, nc in zip(bars, r2_vals, n_obs, n_countries):
        x_mid = bar.get_x() + bar.get_width() / 2
        top = bar.get_height()
        ax.text(x_mid, top + 0.008,
                f'$R^2$ = {r2:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax.text(x_mid, top / 2,
                f'N = {n:,}\n({nc} countries)', ha='center', va='center',
                fontsize=8, color='white', fontweight='bold')

    ax.set_xticks(positions)
    ax.set_xticklabels(models, fontsize=9)
    ax.set_ylabel('$R^2$', fontsize=12)
    ax.set_title('Model Comparison: $R^2$ Across Four Specifications',
                 fontsize=13, fontweight='bold')
    ax.set_ylim(0, max(r2_vals) * 1.25)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "fig7_model_comparison.png", dpi=150, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.close()
    print("  Saved fig7_model_comparison.png")


if __name__ == '__main__':
    # Build every figure in order; any failure aborts before the success message.
    for build_figure in (make_fig1, make_fig3, make_fig6, make_fig7):
        build_figure()
    print("\nAll figures generated successfully.")
